{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Robust Question Answering with Chroma and OpenAI \n",
    "\n",
    "This notebook guides you step-by-step through answering questions about a collection of data, using [Chroma](https://trychroma.com), an open-source embeddings database, along with OpenAI's [text embeddings](https://platform.openai.com/docs/guides/embeddings/use-cases) and [chat completion](https://platform.openai.com/docs/guides/chat) API's. \n",
    "\n",
    "Additionally, this notebook demonstrates some of the tradeoffs in making a question answering system more robust. As we shall see, *simple querying doesn't always create the best results*! \n",
    "\n",
    "## Question Answering with LLMs\n",
    "\n",
    "Large language models (LLMs) like OpenAI's ChatGPT can be used to answer questions about data that the model may not have been trained on, or have access to. For example;\n",
    "\n",
    "- Personal data like e-mails and notes\n",
    "- Highly specialized data like archival or legal documents\n",
    "- Newly created data like recent news stories\n",
    "\n",
    "In order to overcome this limitation, we can use a data store which is amenable to querying in natural language, just like the LLM itself. An embeddings store like Chroma represents documents as [embeddings](https://openai.com/blog/introducing-text-and-code-embeddings), alongside the documents themselves. \n",
    "\n",
    "By embedding a text query, Chroma can find relevant documents, which we can then pass to the LLM to answer our question. We'll show detailed examples and variants of this approach. "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Setup and preliminaries\n",
    "\n",
    "First we make sure the python dependencies we need are installed. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Note: you may need to restart the kernel to use updated packages.\n"
     ]
    }
   ],
   "source": [
    "%pip install -qU openai chromadb pandas"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We use OpenAI's API's throughout this notebook. You can get an API key from [https://beta.openai.com/account/api-keys](https://beta.openai.com/account/api-keys)\n",
    "\n",
    "You can add your API key as an environment variable by executing the command `export OPENAI_API_KEY=sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx` in a terminal. Note that you will need to reload the notebook if the environment variable wasn't set yet. Alternatively, you can set it in the notebook, see below. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "OPENAI_API_KEY is ready\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "\n",
    "# Uncomment the following line to set the environment variable in the notebook\n",
    "# os.environ[\"OPENAI_API_KEY\"] = 'sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx'\n",
    "\n",
    "if os.getenv(\"OPENAI_API_KEY\") is not None:\n",
    "    print(\"OPENAI_API_KEY is ready\")\n",
    "    import openai\n",
    "    openai.api_key = os.getenv(\"OPENAI_API_KEY\")\n",
    "else:\n",
    "    print(\"OPENAI_API_KEY environment variable not found\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Dataset\n",
    "\n",
    "Throughout this notebook, we use the [SciFact dataset](https://github.com/allenai/scifact). This is a curated dataset of expert annotated scientific claims, with an accompanying text corpus of paper titles and abstracts. Each claim may be supported, contradicted, or not have enough evidence either way, according to the documents in the corpus. \n",
    "\n",
    "Having the corpus available as ground-truth allows us to investigate how well the following approaches to LLM question answering perform. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>id</th>\n",
       "      <th>claim</th>\n",
       "      <th>evidence</th>\n",
       "      <th>cited_doc_ids</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1</td>\n",
       "      <td>0-dimensional biomaterials show inductive prop...</td>\n",
       "      <td>{}</td>\n",
       "      <td>[31715818]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>3</td>\n",
       "      <td>1,000 genomes project enables mapping of genet...</td>\n",
       "      <td>{'14717500': [{'sentences': [2, 5], 'label': '...</td>\n",
       "      <td>[14717500]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>5</td>\n",
       "      <td>1/2000 in UK have abnormal PrP positivity.</td>\n",
       "      <td>{'13734012': [{'sentences': [4], 'label': 'SUP...</td>\n",
       "      <td>[13734012]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>13</td>\n",
       "      <td>5% of perinatal mortality is due to low birth ...</td>\n",
       "      <td>{}</td>\n",
       "      <td>[1606628]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>36</td>\n",
       "      <td>A deficiency of vitamin B12 increases blood le...</td>\n",
       "      <td>{}</td>\n",
       "      <td>[5152028, 11705328]</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   id                                              claim  \\\n",
       "0   1  0-dimensional biomaterials show inductive prop...   \n",
       "1   3  1,000 genomes project enables mapping of genet...   \n",
       "2   5         1/2000 in UK have abnormal PrP positivity.   \n",
       "3  13  5% of perinatal mortality is due to low birth ...   \n",
       "4  36  A deficiency of vitamin B12 increases blood le...   \n",
       "\n",
       "                                            evidence        cited_doc_ids  \n",
       "0                                                 {}           [31715818]  \n",
       "1  {'14717500': [{'sentences': [2, 5], 'label': '...           [14717500]  \n",
       "2  {'13734012': [{'sentences': [4], 'label': 'SUP...           [13734012]  \n",
       "3                                                 {}            [1606628]  \n",
       "4                                                 {}  [5152028, 11705328]  "
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Load the claim dataset\n",
    "import pandas as pd\n",
    "\n",
    "data_path = '../../data'\n",
    "\n",
    "claim_df = pd.read_json(f'{data_path}/scifact_claims.jsonl', lines=True)\n",
    "claim_df.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Just asking the model\n",
    "\n",
    "GPT-3.5 was trained on a large amount of scientific information. As a baseline, we'd like to understand what the model already knows without any further context. This will allow us to calibrate overall performance. \n",
    "\n",
    "We construct an appropriate prompt, with some example facts, then query the model with each claim in the dataset. We ask the model to assess a claim as 'True', 'False', or 'NEE' if there is not enough evidence one way or the other. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_prompt(claim):\n",
    "    return [\n",
    "        {\"role\": \"system\", \"content\": \"I will ask you to assess a scientific claim. Output only the text 'True' if the claim is true, 'False' if the claim is false, or 'NEE' if there's not enough evidence.\"},\n",
    "        {\"role\": \"user\", \"content\": f\"\"\"        \n",
    "Example:\n",
    "\n",
    "Claim:\n",
    "0-dimensional biomaterials show inductive properties.\n",
    "\n",
    "Assessment:\n",
    "False\n",
    "\n",
    "Claim:\n",
    "1/2000 in UK have abnormal PrP positivity.\n",
    "\n",
    "Assessment:\n",
    "True\n",
    "\n",
    "Claim:\n",
    "Aspirin inhibits the production of PGE2.\n",
    "\n",
    "Assessment:\n",
    "False\n",
    "\n",
    "End of examples. Assess the following claim:\n",
    "\n",
    "Claim:\n",
    "{claim}\n",
    "\n",
    "Assessment:\n",
    "\"\"\"}\n",
    "    ]\n",
    "\n",
    "\n",
    "def assess_claims(claims):\n",
    "    responses = []\n",
    "    # Query the OpenAI API\n",
    "    for claim in claims:\n",
    "        response = openai.ChatCompletion.create(\n",
    "            model='gpt-3.5-turbo',\n",
    "            messages=build_prompt(claim),\n",
    "            max_tokens=3,\n",
    "        )\n",
    "        # Strip any punctuation or whitespace from the response\n",
    "        responses.append(response.choices[0].message.content.strip('., '))\n",
    "\n",
    "    return responses"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We sample 100 claims from the dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Let's take a look at 100 claims\n",
    "samples = claim_df.sample(50)\n",
    "\n",
    "claims = samples['claim'].tolist() \n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We evaluate the ground-truth according to the dataset. From the dataset description, each claim is either supported or contradicted by the evidence, or else there isn't enough evidence either way. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_groundtruth(evidence):\n",
    "    groundtruth = []\n",
    "    for e in evidence:\n",
    "        # Evidence is empty \n",
    "        if len(e) == 0:\n",
    "            groundtruth.append('NEE')\n",
    "        else:\n",
    "            # In this dataset, all evidence for a given claim is consistent, either SUPPORT or CONTRADICT\n",
    "            if list(e.values())[0][0]['label'] == 'SUPPORT':\n",
    "                groundtruth.append('True')\n",
    "            else:\n",
    "                groundtruth.append('False')\n",
    "    return groundtruth"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "evidence = samples['evidence'].tolist()\n",
    "groundtruth = get_groundtruth(evidence)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We also output the confusion matrix, comparing the model's assessments with the ground truth, in an easy to read table. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def confusion_matrix(inferred, groundtruth):\n",
    "    assert len(inferred) == len(groundtruth)\n",
    "    confusion = {\n",
    "        'True': {'True': 0, 'False': 0, 'NEE': 0},\n",
    "        'False': {'True': 0, 'False': 0, 'NEE': 0},\n",
    "        'NEE': {'True': 0, 'False': 0, 'NEE': 0},\n",
    "    }\n",
    "    for i, g in zip(inferred, groundtruth):\n",
    "        confusion[i][g] += 1\n",
    "\n",
    "    # Pretty print the confusion matrix\n",
    "    print('\\tGroundtruth')\n",
    "    print('\\tTrue\\tFalse\\tNEE')\n",
    "    for i in confusion:\n",
    "        print(i, end='\\t')\n",
    "        for g in confusion[i]:\n",
    "            print(confusion[i][g], end='\\t')\n",
    "        print()\n",
    "\n",
    "    return confusion"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We ask the model to directly assess the claims, without additional context. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\tGroundtruth\n",
      "\tTrue\tFalse\tNEE\n",
      "True\t15\t5\t14\t\n",
      "False\t0\t2\t1\t\n",
      "NEE\t3\t3\t7\t\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'True': {'True': 15, 'False': 5, 'NEE': 14},\n",
       " 'False': {'True': 0, 'False': 2, 'NEE': 1},\n",
       " 'NEE': {'True': 3, 'False': 3, 'NEE': 7}}"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "gpt_inferred = assess_claims(claims)\n",
    "confusion_matrix(gpt_inferred, groundtruth)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Results\n",
    "\n",
    "From these results we see that the LLM is strongly biased to assess claims as true, even when they are false, and also tends to assess false claims as not having enough evidence. Note that 'not enough evidence' is with respect to the model's assessment of the claim in a vacuum, without additional context.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Adding context \n",
    "\n",
    "We now add the additional context available from the corpus of paper titles and abstracts. This section shows how to load a text corpus into Chroma, using OpenAI text embeddings. "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First, we load the text corpus. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>doc_id</th>\n",
       "      <th>title</th>\n",
       "      <th>abstract</th>\n",
       "      <th>structured</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>4983</td>\n",
       "      <td>Microstructural development of human newborn c...</td>\n",
       "      <td>[Alterations of the architecture of cerebral w...</td>\n",
       "      <td>False</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>5836</td>\n",
       "      <td>Induction of myelodysplasia by myeloid-derived...</td>\n",
       "      <td>[Myelodysplastic syndromes (MDS) are age-depen...</td>\n",
       "      <td>False</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>7912</td>\n",
       "      <td>BC1 RNA, the transcript from a master gene for...</td>\n",
       "      <td>[ID elements are short interspersed elements (...</td>\n",
       "      <td>False</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>18670</td>\n",
       "      <td>The DNA Methylome of Human Peripheral Blood Mo...</td>\n",
       "      <td>[DNA methylation plays an important role in bi...</td>\n",
       "      <td>False</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>19238</td>\n",
       "      <td>The human myelin basic protein gene is include...</td>\n",
       "      <td>[Two human Golli (for gene expressed in the ol...</td>\n",
       "      <td>False</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   doc_id                                              title  \\\n",
       "0    4983  Microstructural development of human newborn c...   \n",
       "1    5836  Induction of myelodysplasia by myeloid-derived...   \n",
       "2    7912  BC1 RNA, the transcript from a master gene for...   \n",
       "3   18670  The DNA Methylome of Human Peripheral Blood Mo...   \n",
       "4   19238  The human myelin basic protein gene is include...   \n",
       "\n",
       "                                            abstract  structured  \n",
       "0  [Alterations of the architecture of cerebral w...       False  \n",
       "1  [Myelodysplastic syndromes (MDS) are age-depen...       False  \n",
       "2  [ID elements are short interspersed elements (...       False  \n",
       "3  [DNA methylation plays an important role in bi...       False  \n",
       "4  [Two human Golli (for gene expressed in the ol...       False  "
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Load the corpus into a dataframe\n",
    "corpus_df = pd.read_json(f'{data_path}/scifact_corpus.jsonl', lines=True)\n",
    "corpus_df.head()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Loading the corpus into Chroma\n",
    "\n",
    "The next step is to load the corpus into Chroma. Given an embedding function, Chroma will automatically handle embedding each document, and will store it alongside its text and metadata, making it simple to query."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We instantiate a (ephemeral) Chroma client, and create a collection for the SciFact title and abstract corpus. \n",
    "Chroma can also be instantiated in a persisted configuration; learn more at the [Chroma docs](https://docs.trychroma.com/usage-guide?lang=py). "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running Chroma using direct local API.\n",
      "Using DuckDB in-memory for database. Data will be transient.\n"
     ]
    }
   ],
   "source": [
    "import chromadb\n",
    "from chromadb.utils.embedding_functions import OpenAIEmbeddingFunction\n",
    "\n",
    "# We initialize an embedding function, and provide it to the collection.\n",
    "embedding_function = OpenAIEmbeddingFunction(api_key=os.getenv(\"OPENAI_API_KEY\"))\n",
    "\n",
    "chroma_client = chromadb.Client() # Ephemeral by default\n",
    "scifact_corpus_collection = chroma_client.create_collection(name='scifact_corpus', embedding_function=embedding_function)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next we load the corpus into Chroma. Because this data loading is memory intensive, we recommend using a batched loading scheme in batches of 50-1000. For this example it should take just over one minute for the entire corpus. It's being embedded in the background, automatically, using the `embedding_function` we specified earlier."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 100\n",
    "\n",
    "for i in range(0, len(corpus_df), batch_size):\n",
    "    batch_df = corpus_df[i:i+batch_size]\n",
    "    scifact_corpus_collection.add(\n",
    "        ids=batch_df['doc_id'].apply(lambda x: str(x)).tolist(), # Chroma takes string IDs.\n",
    "        documents=(batch_df['title'] + '. ' + batch_df['abstract'].apply(lambda x: ' '.join(x))).to_list(), # We concatenate the title and abstract.\n",
    "        metadatas=[{\"structured\": structured} for structured in batch_df['structured'].to_list()] # We also store the metadata, though we don't use it in this example.\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Retrieving context\n",
    "\n",
    "Next we retrieve documents from the corpus which may be relevant to each claim in our sample. We want to provide these as context to the LLM for evaluating the claims. We retrieve the 3 most relevant documents for each claim, according to the embedding distance. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "claim_query_result = scifact_corpus_collection.query(query_texts=claims, include=['documents', 'distances'], n_results=3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We create a new prompt, this time taking into account the additional context we retrieve from the corpus. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_prompt_with_context(claim, context):\n",
    "    return [{'role': 'system', 'content': \"I will ask you to assess whether a particular scientific claim, based on evidence provided. Output only the text 'True' if the claim is true, 'False' if the claim is false, or 'NEE' if there's not enough evidence.\"}, \n",
    "            {'role': 'user', 'content': f\"\"\"\"\n",
    "The evidence is the following:\n",
    "\n",
    "{' '.join(context)}\n",
    "\n",
    "Assess the following claim on the basis of the evidence. Output only the text 'True' if the claim is true, 'False' if the claim is false, or 'NEE' if there's not enough evidence. Do not output any other text. \n",
    "\n",
    "Claim:\n",
    "{claim}\n",
    "\n",
    "Assessment:\n",
    "\"\"\"}]\n",
    "\n",
    "\n",
    "def assess_claims_with_context(claims, contexts):\n",
    "    responses = []\n",
    "    # Query the OpenAI API\n",
    "    for claim, context in zip(claims, contexts):\n",
    "        # If no evidence is provided, return NEE\n",
    "        if len(context) == 0:\n",
    "            responses.append('NEE')\n",
    "            continue\n",
    "        response = openai.ChatCompletion.create(\n",
    "            model='gpt-3.5-turbo',\n",
    "            messages=build_prompt_with_context(claim=claim, context=context),\n",
    "            max_tokens=3,\n",
    "        )\n",
    "        # Strip any punctuation or whitespace from the response\n",
    "        responses.append(response.choices[0].message.content.strip('., '))\n",
    "\n",
    "    return responses"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then ask the model to evaluate the claims with the retrieved context. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\tGroundtruth\n",
      "\tTrue\tFalse\tNEE\n",
      "True\t16\t2\t8\t\n",
      "False\t1\t6\t5\t\n",
      "NEE\t1\t2\t9\t\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'True': {'True': 16, 'False': 2, 'NEE': 8},\n",
       " 'False': {'True': 1, 'False': 6, 'NEE': 5},\n",
       " 'NEE': {'True': 1, 'False': 2, 'NEE': 9}}"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "gpt_with_context_evaluation = assess_claims_with_context(claims, claim_query_result['documents'])\n",
    "confusion_matrix(gpt_with_context_evaluation, groundtruth)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Results\n",
    "\n",
    "We see that the model is a lot less likely to evaluate a False claim as true (2 instances VS 5 previously), but that claims without enough evidence are still often assessed as True or False.\n",
    "\n",
    "Taking a look at the retrieved documents, we see that they are sometimes not relevant to the claim - this causes the model to be confused by the extra information, and it may decide that sufficient evidence is present, even when the information is irrelevant. This happens because we always ask for the 3 'most' relevant documents, but these might not be relevant at all beyond a certain point.  "
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Filtering context on relevance\n",
    "\n",
    "Along with the documents themselves, Chroma returns a distance score. We can try thresholding on distance, so that fewer irrelevant documents make it into the context we provide the model. \n",
    "\n",
    "If, after filtering on the threshold, no context documents remain, we bypass the model and simply return that there is not enough evidence. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "def filter_query_result(query_result, distance_threshold=0.25):\n",
    "# For each query result, retain only the documents whose distance is below the threshold\n",
    "    for ids, docs, distances in zip(query_result['ids'], query_result['documents'], query_result['distances']):\n",
    "        for i in range(len(ids)-1, -1, -1):\n",
    "            if distances[i] > distance_threshold:\n",
    "                ids.pop(i)\n",
    "                docs.pop(i)\n",
    "                distances.pop(i)\n",
    "    return query_result\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "filtered_claim_query_result = filter_query_result(claim_query_result)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we assess the claims using this cleaner context. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\tGroundtruth\n",
      "\tTrue\tFalse\tNEE\n",
      "True\t10\t2\t1\t\n",
      "False\t0\t2\t1\t\n",
      "NEE\t8\t6\t20\t\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'True': {'True': 10, 'False': 2, 'NEE': 1},\n",
       " 'False': {'True': 0, 'False': 2, 'NEE': 1},\n",
       " 'NEE': {'True': 8, 'False': 6, 'NEE': 20}}"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "gpt_with_filtered_context_evaluation = assess_claims_with_context(claims, filtered_claim_query_result['documents'])\n",
    "confusion_matrix(gpt_with_filtered_context_evaluation, groundtruth)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Results\n",
    "\n",
    "The model now assesses many fewer claims as True or False when there is not enough evidence present. However, it now biases away from certainty. Most claims are now assessed as having not enough evidence, because a large fraction of them are filtered out by the distance threshold. It's possible to tune the distance threshold to find the optimal operating point, but this can be difficult, and is dataset and embedding model dependent. "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Hypothetical Document Embeddings: Using hallucinations productively\n",
    "\n",
    "We want to be able to retrieve relevant documents, without retrieving less relevant ones which might confuse the model. One way to accomplish this is to improve the retrieval query. \n",
    "\n",
    "Until now, we have queried the dataset using _claims_ which are single sentence statements, while the corpus contains _abstracts_ describing a scientific paper. Intuitively, while these might be related, there are significant differences in their structure and meaning. These differences are encoded by the embedding model, and so influence the distances between the query and the most relevant results. \n",
    "\n",
    "We can overcome this by leveraging the power of LLMs to generate relevant text. While the facts might be hallucinated, the content and structure of the documents the models generate is more similar to the documents in our corpus, than the queries are. This could lead to better queries and hence better results. \n",
    "\n",
    "This approach is called [Hypothetical Document Embeddings (HyDE)](https://arxiv.org/abs/2212.10496), and has been shown to be quite good at the retrieval task. It should help us bring more relevant information into the context, without polluting it.  \n",
    "\n",
    "TL;DR:\n",
    "- you get much better matches when you embed whole abstracts rather than single sentences\n",
    "- but claims are usually single sentences\n",
    "- So HyDE shows that using GPT3 to expand claims into hallucinated abstracts and then searching based on those abstracts works (claims -> abstracts -> results) better than searching directly (claims -> results)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First, we use in-context examples to prompt the model to generate documents similar to what's in the corpus, for each claim we want to assess. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_hallucination_prompt(claim):\n",
    "    return [{'role': 'system', 'content': \"\"\"I will ask you to write an abstract for a scientific paper which supports or refutes a given claim. It should be written in scientific language, include a title. Output only one abstract, then stop.\n",
    "    \n",
    "    An Example:\n",
    "\n",
    "    Claim:\n",
    "    A high microerythrocyte count raises vulnerability to severe anemia in homozygous alpha (+)- thalassemia trait subjects.\n",
    "\n",
    "    Abstract:\n",
    "    BACKGROUND The heritable haemoglobinopathy alpha(+)-thalassaemia is caused by the reduced synthesis of alpha-globin chains that form part of normal adult haemoglobin (Hb). Individuals homozygous for alpha(+)-thalassaemia have microcytosis and an increased erythrocyte count. Alpha(+)-thalassaemia homozygosity confers considerable protection against severe malaria, including severe malarial anaemia (SMA) (Hb concentration < 50 g/l), but does not influence parasite count. We tested the hypothesis that the erythrocyte indices associated with alpha(+)-thalassaemia homozygosity provide a haematological benefit during acute malaria.   \n",
    "    METHODS AND FINDINGS Data from children living on the north coast of Papua New Guinea who had participated in a case-control study of the protection afforded by alpha(+)-thalassaemia against severe malaria were reanalysed to assess the genotype-specific reduction in erythrocyte count and Hb levels associated with acute malarial disease. We observed a reduction in median erythrocyte count of approximately 1.5 x 10(12)/l in all children with acute falciparum malaria relative to values in community children (p < 0.001). We developed a simple mathematical model of the linear relationship between Hb concentration and erythrocyte count. This model predicted that children homozygous for alpha(+)-thalassaemia lose less Hb than children of normal genotype for a reduction in erythrocyte count of >1.1 x 10(12)/l as a result of the reduced mean cell Hb in homozygous alpha(+)-thalassaemia. In addition, children homozygous for alpha(+)-thalassaemia require a 10% greater reduction in erythrocyte count than children of normal genotype (p = 0.02) for Hb concentration to fall to 50 g/l, the cutoff for SMA. We estimated that the haematological profile in children homozygous for alpha(+)-thalassaemia reduces the risk of SMA during acute malaria compared to children of normal genotype (relative risk 0.52; 95% confidence interval [CI] 0.24-1.12, p = 0.09).   \n",
    "    CONCLUSIONS The increased erythrocyte count and microcytosis in children homozygous for alpha(+)-thalassaemia may contribute substantially to their protection against SMA. A lower concentration of Hb per erythrocyte and a larger population of erythrocytes may be a biologically advantageous strategy against the significant reduction in erythrocyte count that occurs during acute infection with the malaria parasite Plasmodium falciparum. This haematological profile may reduce the risk of anaemia by other Plasmodium species, as well as other causes of anaemia. Other host polymorphisms that induce an increased erythrocyte count and microcytosis may confer a similar advantage.\n",
    "\n",
    "    End of example. \n",
    "    \n",
    "    \"\"\"}, {'role': 'user', 'content': f\"\"\"\"\n",
    "    Perform the task for the following claim.\n",
    "\n",
    "    Claim:\n",
    "    {claim}\n",
    "\n",
    "    Abstract:\n",
    "    \"\"\"}]\n",
    "\n",
    "\n",
    "def hallucinate_evidence(claims):\n",
    "    # Query the OpenAI API\n",
    "    responses = []\n",
    "    # Query the OpenAI API\n",
    "    for claim in claims:\n",
    "        response = openai.ChatCompletion.create(\n",
    "            model='gpt-3.5-turbo',\n",
    "            messages=build_hallucination_prompt(claim),\n",
    "        )\n",
    "        responses.append(response.choices[0].message.content)\n",
    "    return responses"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We hallucinate a document for each claim.\n",
    "\n",
    "*NB: This can take a while, about 30m for 100 claims*. You can reduce the number of claims we want to assess to get results more quickly. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "hallucinated_evidence = hallucinate_evidence(claims)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We use the hallucinated documents as queries into the corpus, and filter the results using the same distance threshold. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "hallucinated_query_result = scifact_corpus_collection.query(query_texts=hallucinated_evidence, include=['documents', 'distances'], n_results=3)\n",
    "filtered_hallucinated_query_result = filter_query_result(hallucinated_query_result)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We then ask the model to assess the claims, using the new context. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\tGroundtruth\n",
      "\tTrue\tFalse\tNEE\n",
      "True\t15\t2\t5\t\n",
      "False\t1\t5\t4\t\n",
      "NEE\t2\t3\t13\t\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'True': {'True': 15, 'False': 2, 'NEE': 5},\n",
       " 'False': {'True': 1, 'False': 5, 'NEE': 4},\n",
       " 'NEE': {'True': 2, 'False': 3, 'NEE': 13}}"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "gpt_with_hallucinated_context_evaluation = assess_claims_with_context(claims, filtered_hallucinated_query_result['documents'])\n",
    "confusion_matrix(gpt_with_hallucinated_context_evaluation, groundtruth)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Results\n",
    "\n",
    "Combining HyDE with a simple distance threshold leads to a significant improvement. The model no longer biases assessing claims as True, nor toward their not being enough evidence. It also correctly assesses when there isn't enough evidence more often."
   ]
  },
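  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough way to put numbers on the result above, the sketch below computes overall accuracy and per-label recall from the confusion matrix. It assumes the outer keys are the model's assessments and the inner keys are the groundtruth labels, as the printed table suggests. `accuracy` and `per_label_recall` are hypothetical helper names, not part of the notebook.\n",
    "\n",
    "```python\n",
    "# Confusion matrix copied from the output above.\n",
    "# Assumption: outer keys are model assessments, inner keys are groundtruth labels.\n",
    "matrix = {\n",
    "    'True': {'True': 15, 'False': 2, 'NEE': 5},\n",
    "    'False': {'True': 1, 'False': 5, 'NEE': 4},\n",
    "    'NEE': {'True': 2, 'False': 3, 'NEE': 13},\n",
    "}\n",
    "\n",
    "\n",
    "def accuracy(matrix):\n",
    "    # Fraction of claims where the assessment matches the groundtruth (the diagonal)\n",
    "    correct = sum(matrix[label][label] for label in matrix)\n",
    "    total = sum(sum(row.values()) for row in matrix.values())\n",
    "    return correct / total\n",
    "\n",
    "\n",
    "def per_label_recall(matrix):\n",
    "    # For each groundtruth label, the fraction of its claims assessed correctly\n",
    "    recall = {}\n",
    "    for label in matrix:\n",
    "        groundtruth_total = sum(matrix[assessed][label] for assessed in matrix)\n",
    "        recall[label] = matrix[label][label] / groundtruth_total if groundtruth_total else 0.0\n",
    "    return recall\n",
    "\n",
    "\n",
    "print(f'accuracy: {accuracy(matrix):.2f}')\n",
    "for label, value in per_label_recall(matrix).items():\n",
    "    print(f'recall[{label}]: {value:.2f}')\n",
    "```\n",
    "\n",
    "For the matrix above, 33 of the 50 claims lie on the diagonal, an accuracy of 0.66. Applying the same helpers to the earlier confusion matrices would give a like-for-like comparison between the approaches."
   ]
  },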
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Conclusion\n",
    "\n",
    "Equipping LLMs with a context based on a corpus of documents is a powerful technique for bringing the general reasoning and natural language interactions of LLMs to your own data. However, it's important to know that naive query and retrieval may not produce the best possible results! Ultimately understanding the data will help get the most out of the retrieval based question-answering approach. \n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  },
  "vscode": {
   "interpreter": {
    "hash": "fd16a328ca3d68029457069b79cb0b38eb39a0f5ccc4fe4473d3047707df8207"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
